/*
 * AIEngine a new generation network intrusion detection system.
 *
 * Copyright (C) 2013-2023  Luis Campo Giralte
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Library General Public
 * License as published by the Free Software Foundation; either
 * version 2 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Library General Public License for more details.
 *
 * You should have received a copy of the GNU Library General Public
 * License along with this library; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA  02110-1301, USA.
 *
 * Written by Luis Campo Giralte <luis.camp0.2009@gmail.com>
 *
 */
#include "QuicProtocol.h"
#include <iomanip>

namespace aiengine {

// Build the analyzer: the Protocol base is given the name "Quic" and the
// transport protocol number IPPROTO_UDP.
QuicProtocol::QuicProtocol()
	: Protocol("Quic", IPPROTO_UDP)
{
}

// Decide whether `packet` looks like the opening QUIC packet of a flow.
// Updates total_valid_packets_ / total_invalid_packets_ and returns the verdict.
bool QuicProtocol::check(const Packet &packet) {

	const int length = packet.getLength();
	bool accepted = false;

	if (length >= header_size) {
		// Quic is started by the clients, so only the destination port is
		// checked, and the first packet read must carry packet number one.
		const auto port = packet.getDestinationPort();
		if ((port == 80)or(port == 443)) {
			setHeader(packet.getPayload());
			accepted = (header_->flags == 0x0D)and(header_->pkt_number == 0x01);
		}
	}

	if (accepted)
		++total_valid_packets_;
	else
		++total_invalid_packets_;

	return accepted;
}

// Apply the same dynamic-allocation setting to all three caches owned here
// (flow info, host names, user agents).
void QuicProtocol::setDynamicAllocatedMemory(bool value) {

	info_cache_->setDynamicAllocatedMemory(value);
	host_cache_->setDynamicAllocatedMemory(value);
	ua_cache_->setDynamicAllocatedMemory(value);
}

// Report the dynamic-allocation setting. setDynamicAllocatedMemory writes the
// same value to every cache, so the info cache's flag is the one returned.
bool QuicProtocol::isDynamicAllocatedMemory() const {

	const bool dynamic = info_cache_->isDynamicAllocatedMemory();
	return dynamic;
}

// Bytes currently in use: the protocol object plus the in-use memory reported
// by each of its three caches.
uint64_t QuicProtocol::getCurrentUseMemory() const {

	return static_cast<uint64_t>(sizeof(QuicProtocol))
		+ info_cache_->getCurrentUseMemory()
		+ host_cache_->getCurrentUseMemory()
		+ ua_cache_->getCurrentUseMemory();
}

// Bytes allocated: the protocol object plus the allocated memory reported by
// each of its three caches (string bytes held by the maps are not included).
uint64_t QuicProtocol::getAllocatedMemory() const {

	return static_cast<uint64_t>(sizeof(QuicProtocol))
		+ info_cache_->getAllocatedMemory()
		+ host_cache_->getAllocatedMemory()
		+ ua_cache_->getAllocatedMemory();
}

// Allocated memory including the key strings stored in the host and
// user-agent maps.
uint64_t QuicProtocol::getTotalAllocatedMemory() const {

	const uint64_t cache_bytes = getAllocatedMemory();
	const uint64_t map_bytes = compute_memory_used_by_maps();

	return cache_bytes + map_bytes;
}

// Sum of the key string lengths stored in host_map_ and ua_map_.
uint64_t QuicProtocol::compute_memory_used_by_maps() const {

	uint64_t total = 0;

	for (const auto &entry: host_map_)
		total += entry.first.size();

	for (const auto &entry: ua_map_)
		total += entry.first.size();

	return total;
}

// Total acquire failures across the info, host and user-agent caches.
uint32_t QuicProtocol::getTotalCacheMisses() const {

	uint32_t misses = info_cache_->getTotalFails();

	misses += host_cache_->getTotalFails();
	misses += ua_cache_->getTotalFails();

	return misses;
}

void QuicProtocol::releaseCache() {

        if (FlowManagerPtr fm = flow_mng_.lock(); fm) {
                auto ft = fm->getFlowTable();

                std::ostringstream msg;
                msg << "Releasing " << name() << " cache";

                infoMessage(msg.str());

                uint64_t total_cache_bytes_released = compute_memory_used_by_maps();
                uint64_t total_bytes_released_by_flows = 0;
                uint64_t total_cache_save_bytes = 0;
                uint32_t release_flows = 0;
                uint32_t release_hosts = host_map_.size();
                uint32_t release_uas = ua_map_.size();

                for (auto &flow: ft) {
                        if (SharedPointer<QuicInfo> info = flow->getQuicInfo(); info) {
                                total_bytes_released_by_flows += sizeof(info);

                                flow->layer7info.reset();
                                ++release_flows;
                                info_cache_->release(info);
                        }
                }
                // Some entries can be still on the maps and needs to be
                // retrieve to their existing caches
                for (auto &entry: host_map_) {
			total_cache_save_bytes += entry.second.sc->size() * (entry.second.hits - 1);
                        releaseStringToCache(host_cache_, entry.second.sc);
		}
                host_map_.clear();

                for (auto &entry: ua_map_) {
			total_cache_save_bytes += entry.second.sc->size() * (entry.second.hits - 1);
                        releaseStringToCache(ua_cache_, entry.second.sc);
		}
                ua_map_.clear();

		data_time_ = boost::posix_time::microsec_clock::local_time();

                msg.str("");
                msg << "Release " << release_hosts << " host names, ";
		msg << release_uas << " user agents, " << release_flows << " flows";
		computeMemoryUtilization(msg, total_cache_bytes_released, total_bytes_released_by_flows, total_cache_save_bytes);
		infoMessage(msg.str());
        }
}

// Give the flow's QuicInfo, when it has one, back to the info cache.
void QuicProtocol::releaseFlowInfo(Flow *flow) {

	auto quic_info = flow->getQuicInfo();
	if (!quic_info)
		return;

	info_cache_->release(quic_info);
}

// Replace the domain name manager used for server-name matching.
// Clears the plugged-to name on the previous manager. If `dm` is non-null, it
// becomes the new manager and gets this protocol's name; otherwise the manager
// is dropped.
void QuicProtocol::setDomainNameManager(const SharedPointer<DomainNameManager> &dm) {

	if (domain_mng_)
		domain_mng_->setPluggedToName("");

	if (!dm) {
		domain_mng_.reset();
		return;
	}

	domain_mng_ = dm;
	domain_mng_->setPluggedToName(name());
}

// Attach `name` as the flow's server name, sharing one StringCache per distinct
// name through host_map_. A flow that already has a server name keeps it.
// If the host cache cannot supply a string, the name is silently not attached.
void QuicProtocol::attach_host_name(QuicInfo *info, const boost::string_ref &name) {

	if (info->server_name)
		return;

	StringMap::iterator it = host_map_.find(name);
	if (it != host_map_.end()) {
		// Known name: reuse the shared string and record another hit.
		auto &known = it->second;
		++known.hits;
		info->server_name = known.sc;
		return;
	}

	SharedPointer<StringCache> fresh = host_cache_->acquire();
	if (!fresh)
		return;

	fresh->name(name.data(), name.length());
	info->server_name = fresh;
	host_map_.insert(std::make_pair(fresh->name(), fresh));
}

// Attach `name` as the flow's user agent, sharing one StringCache per distinct
// value through ua_map_. A flow that already has a user agent keeps it.
// If the user-agent cache cannot supply a string, nothing is attached.
void QuicProtocol::attach_user_agent(QuicInfo *info, const boost::string_ref &name) {

	if (info->user_agent)
		return;

	StringMap::iterator it = ua_map_.find(name);
	if (it != ua_map_.end()) {
		// Known value: reuse the shared string and record another hit.
		auto &known = it->second;
		++known.hits;
		info->user_agent = known.sc;
		return;
	}

	SharedPointer<StringCache> fresh = ua_cache_->acquire();
	if (!fresh)
		return;

	fresh->name(name.data(), name.length());
	info->user_agent = fresh;
	ua_map_.insert(std::make_pair(fresh->name(), fresh));
}

// Walk the tag/value table of a client hello and attach the SNI and UAID values,
// when present, to `info`.
//
// Layout assumed, matching how this code has always addressed the data:
// `tags` consecutive quic_tag_value entries, each holding a 4-byte tag and the
// END offset of its value. The value area starts right after the table. The
// value of entry i spans [end offset of entry i-1, end offset of entry i), and
// the first value starts at offset 0.
//
// @param info    flow attributes that receive the host name / user agent.
// @param data    start of the tag table (untrusted packet bytes).
// @param length  number of readable bytes at `data`; nothing outside
//                [data, data + length) is read.
// @param tags    number of entries declared for the tag table.
void QuicProtocol::handle_client_hello(QuicInfo *info, const uint8_t *data, int length, int tags) {

	if ((length <= 0)or(tags <= 0))
		return;

	const int64_t entry_size = static_cast<int64_t>(sizeof(quic_tag_value));
	const int64_t table_size = static_cast<int64_t>(tags) * entry_size;

	// A tag table that does not fit in the buffer is truncated or bogus.
	if (table_size > length)
		return;

	const uint8_t *values = &data[table_size];
	const int64_t values_length = static_cast<int64_t>(length) - table_size;

	int64_t host_start = 0, host_end = 0;
	int64_t ua_start = 0, ua_end = 0;
	int64_t start = 0;

	for (int i = 0; i < tags; ++i) {
		const quic_tag_value *tag = reinterpret_cast<const quic_tag_value*>(&data[i * entry_size]);
		const int64_t end = tag->length;

		// End offsets must be non-decreasing and stay inside the value area.
		// Anything else is malformed, so stop; slices recorded so far were
		// already validated.
		if ((end < start)or(end > values_length))
			break;

		if ((tag->tag[0] == 'S')and(tag->tag[1] == 'N')and(tag->tag[2] == 'I')) {
			host_start = start;
			host_end = end;
		} else if ((tag->tag[0] == 'U')and(tag->tag[1] == 'A')and(tag->tag[2] == 'I')and(tag->tag[3] == 'D')) {
			ua_start = start;
			ua_end = end;
		}
		start = end;
	}

	// Empty values (or tags that were not found) are not attached.
	if (host_end > host_start) {
		boost::string_ref host_name(reinterpret_cast<const char*>(&values[host_start]),
			static_cast<std::size_t>(host_end - host_start));
		attach_host_name(info, host_name);
	}

	if (ua_end > ua_start) {
		boost::string_ref ua_name(reinterpret_cast<const char*>(&values[ua_start]),
			static_cast<std::size_t>(ua_end - ua_start));
		attach_user_agent(info, ua_name);
	}
}

// Per-packet entry point for QUIC flows. The steps are:
// 1. Update the byte/packet counters.
// 2. Make sure the flow owns a QuicInfo (returns early if the info cache is
//    exhausted).
// 3. On the flow's first packet only, parse a CHLO frame, if there is one, and
//    match the extracted server name against the domain name manager.
void QuicProtocol::processFlow(Flow *flow) {

	CPUCycle cycles(&total_cpu_cycles_);
	int length = flow->packet->getLength();
	total_bytes_ += length;
	++total_packets_;

	current_flow_ = flow;

	if (length < header_size)
		return;

	SharedPointer<QuicInfo> info = flow->getQuicInfo();
	if (!info) {
		if (info = info_cache_->acquire(); !info) {
			logFailCache(info_cache_->name(), flow);
			return;
		}
		flow->layer7info = info;
	}

	// The first packet is the client hello; later packets are not inspected.
	if (flow->total_packets[0] + flow->total_packets[1] != 1)
		return;

	// NOTE(review): the extra 12 + 1 bytes presumably skip the gQUIC message
	// authentication hash and the frame type byte — confirm against the
	// quic_header / quic_frame definitions.
	int offset = sizeof(quic_header) + 12 + 1;

	if (length > offset + (int)sizeof(quic_frame)) {
		const uint8_t *payload = flow->packet->getPayload();
		const quic_frame *frame = reinterpret_cast<const quic_frame*>(&payload[offset]);

		if ((frame->tag[0] == 'C')and(frame->tag[1] == 'H')and(frame->tag[2] == 'L')and(frame->tag[3] == 'O')) {
			offset += sizeof(quic_frame);

			// The parser reads from one byte past the frame header. Bound it by
			// the bytes actually present in the payload, not by frame->length:
			// that field is taken straight from the wire (its byte order is not
			// checked here), so it cannot be trusted as a buffer size.
			// The enclosing length test guarantees `available` >= 0.
			const int available = length - (offset + 1);

			handle_client_hello(info.get(), &payload[offset + 1], available, frame->tags);
		}
	}

	if ((domain_mng_)and(info->server_name)) {
		if (auto host_candidate = domain_mng_->getDomainName(info->server_name->name()); host_candidate) {
			++total_events_;
			info->matched_domain_name = host_candidate;
#if defined(BINDING)
			if (host_candidate->call.haveCallback())
				host_candidate->call.executeCallback(flow);
#endif
		}
	}
}

// Write human-readable statistics to `out`. Higher `level` values add detail:
// - level > 5: the flow forwarder's statistics.
// - level > 3: the info, host and user-agent cache statistics.
// - level > 4: a dump of the host and user-agent maps, with `limit` passed to
//   each map dump.
void QuicProtocol::statistics(std::basic_ostream<char> &out, int level, int32_t limit) const {

	showStatisticsHeader(out, level);

	if (level > 5) {
		// Lock the weak forwarder once and reuse the result. A second lock()
		// could return null if the forwarder expired in between.
		if (auto forwarder = flow_forwarder_.lock(); forwarder)
			forwarder->statistics(out);
	}

	if (level > 3) {
		info_cache_->statistics(out);
		host_cache_->statistics(out);
		ua_cache_->statistics(out);
		if (level > 4) {
			host_map_.show(out, "\t", limit);
			ua_map_.show(out, "\t", limit);
		}
	}
}

// JSON statistics: only the common header is emitted; `level` is forwarded to it.
void QuicProtocol::statistics(Json &out, int level) const {

	this->showStatisticsHeader(out, level);
}

// Expose the packet and byte totals under the keys "packets" and "bytes".
CounterMap QuicProtocol::getCounters() const {

	CounterMap counters;

	counters.addKeyValue("packets", total_packets_);
	counters.addKeyValue("bytes", total_bytes_);

	return counters;
}

// Ask each of the three caches (info, host, user agent) to create `value` more items.
void QuicProtocol::increaseAllocatedMemory(int value) {

	const int amount = value;

	info_cache_->create(amount);
	host_cache_->create(amount);
	ua_cache_->create(amount);
}

// Ask each of the three caches (info, host, user agent) to destroy `value` items.
void QuicProtocol::decreaseAllocatedMemory(int value) {

	const int amount = value;

	info_cache_->destroy(amount);
	host_cache_->destroy(amount);
	ua_cache_->destroy(amount);
}

// Dump one of the name maps as JSON. `map_name` is compared case-insensitively:
// "hosts" selects host_map_ and "useragents" selects ua_map_. Any other name
// writes nothing.
void QuicProtocol::statistics(Json &out, const std::string &map_name, int32_t limit) const {

	if (boost::iequals(map_name, "hosts"))
		host_map_.show(out, limit);
	else if (boost::iequals(map_name, "useragents"))
		ua_map_.show(out, limit);
}

// Run the common reset(), then zero the domain-match event counter
// (incremented in processFlow when a server name matches).
void QuicProtocol::resetCounters() {

	reset();
	total_events_ = 0;
}

} // namespace aiengine
